{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-22T11:32:57.784218Z",
     "start_time": "2020-09-22T11:32:57.781673Z"
    }
   },
   "outputs": [],
   "source": [
    "# Uncomment to install the Foolbox version this notebook was executed with.\n",
    "# `%pip` installs into the environment of the running kernel, whereas\n",
    "# `!sudo pip3` targets whichever pip3 is first on PATH (possibly another interpreter).\n",
    "# %pip install \"foolbox==3.1.1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-22T11:32:58.767973Z",
     "start_time": "2020-09-22T11:32:57.786360Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.4.0+cu100\n",
      "0.5.0+cu100\n",
      "3.1.1\n",
      "0.29.0\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "import eagerpy as ep\n",
    "import foolbox\n",
    "import torch\n",
    "import torchvision\n",
    "\n",
    "# record library versions (torch, torchvision, foolbox, eagerpy) for reproducibility\n",
    "for library in (torch, torchvision, foolbox, ep):\n",
    "    print(library.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-22T11:32:59.543530Z",
     "start_time": "2020-09-22T11:32:58.770868Z"
    }
   },
   "outputs": [],
   "source": [
    "# pretrained ResNet-18, switched to inference mode before it is wrapped and attacked\n",
    "model = torchvision.models.resnet18(pretrained=True)\n",
    "model = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-22T11:33:02.186293Z",
     "start_time": "2020-09-22T11:32:59.546929Z"
    }
   },
   "outputs": [],
   "source": [
    "# Per-channel normalization is handed to the Foolbox wrapper, so attacks operate\n",
    "# on un-normalized images in bounds (0, 1); axis=-3 selects the channel axis of\n",
    "# the (N, C, H, W) image batches used below.\n",
    "preprocessing = dict(\n",
    "    mean=[0.485, 0.456, 0.406],\n",
    "    std=[0.229, 0.224, 0.225],\n",
    "    axis=-3,\n",
    ")\n",
    "fmodel = foolbox.PyTorchModel(model, bounds=(0, 1), preprocessing=preprocessing)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-22T11:33:02.484611Z",
     "start_time": "2020-09-22T11:33:02.188204Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 3, 224, 224]) torch.Size([16])\n"
     ]
    }
   ],
   "source": [
    "# Load a batch of 16 ImageNet sample images (and their labels) via Foolbox.\n",
    "# Converting them to EagerPy tensors is optional; it lets the following cells\n",
    "# use the framework-agnostic EagerPy API (e.g. .float32(), .mean(axis=...)).\n",
    "raw_images, raw_labels = foolbox.samples(fmodel, dataset=\"imagenet\", batchsize=16)\n",
    "images, labels = ep.astensors(raw_images, raw_labels)\n",
    "print(images.shape, labels.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-22T11:33:02.508440Z",
     "start_time": "2020-09-22T11:33:02.490161Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9375\n"
     ]
    }
   ],
   "source": [
    "# accuracy on the unperturbed sample batch\n",
    "clean_accuracy = foolbox.accuracy(fmodel, images, labels)\n",
    "print(clean_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-22T11:33:09.191417Z",
     "start_time": "2020-09-22T11:33:02.510690Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8.21 ms ± 54.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "# full 16-image batch; compare with the single-image timing that follows\n",
    "foolbox.accuracy(fmodel, images, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-22T11:33:12.460870Z",
     "start_time": "2020-09-22T11:33:09.194373Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3.99 ms ± 131 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "# a single image, for comparison with the full-batch timing\n",
    "foolbox.accuracy(fmodel, images[:1], labels[:1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-22T11:33:13.524389Z",
     "start_time": "2020-09-22T11:33:12.463266Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "attack took 1.1 seconds\n"
     ]
    }
   ],
   "source": [
    "# apply the attack\n",
    "# Per the Foolbox docs, LinfPGD uses a random start by default; seeding torch's\n",
    "# RNG makes the reported robust accuracy reproducible across re-runs.\n",
    "torch.manual_seed(0)\n",
    "# perf_counter is monotonic and high-resolution, unlike time.time()\n",
    "start = time.perf_counter()\n",
    "attack = foolbox.attacks.LinfPGD()\n",
    "epsilons = [0.002]\n",
    "advs, _, success = attack(fmodel, images, labels, epsilons=epsilons)\n",
    "print(f\"attack took {time.perf_counter() - start:.1f} seconds\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-22T11:33:13.607288Z",
     "start_time": "2020-09-22T11:33:13.526656Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.002 0.0625\n"
     ]
    }
   ],
   "source": [
    "# success holds one row per epsilon and one entry per sample; a sample counts\n",
    "# as robust at a given epsilon when the attack did not succeed on it\n",
    "success_rate = success.float32().mean(axis=-1)\n",
    "robust_accuracy = 1 - success_rate\n",
    "for eps, robust_acc in zip(epsilons, robust_accuracy):\n",
    "    print(eps, robust_acc.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-22T11:33:22.619894Z",
     "start_time": "2020-09-22T11:33:13.609361Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "attack took 9.0 seconds\n"
     ]
    }
   ],
   "source": [
    "# apply the attack\n",
    "# Per the Foolbox docs, LinfPGD uses a random start by default; seeding torch's\n",
    "# RNG makes the reported robust accuracies reproducible across re-runs.\n",
    "torch.manual_seed(0)\n",
    "# perf_counter is monotonic and high-resolution, unlike time.time()\n",
    "start = time.perf_counter()\n",
    "attack = foolbox.attacks.LinfPGD()\n",
    "epsilons = [0.0, 0.001, 0.01, 0.03, 0.1, 0.3, 0.5, 1.0]\n",
    "advs, _, success = attack(fmodel, images, labels, epsilons=epsilons)\n",
    "print(f\"attack took {time.perf_counter() - start:.1f} seconds\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-22T11:33:22.701509Z",
     "start_time": "2020-09-22T11:33:22.622579Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0 0.9375\n",
      "0.001 0.375\n",
      "0.01 0.0\n",
      "0.03 0.0\n",
      "0.1 0.0\n",
      "0.3 0.0\n",
      "0.5 0.0\n",
      "1.0 0.0\n"
     ]
    }
   ],
   "source": [
    "# success holds one row per epsilon and one entry per sample; a sample counts\n",
    "# as robust at a given epsilon when the attack did not succeed on it. Samples\n",
    "# that are already misclassified count as successes, which is why eps=0.0\n",
    "# reproduces the clean accuracy.\n",
    "success_rate = success.float32().mean(axis=-1)\n",
    "robust_accuracy = 1 - success_rate\n",
    "for eps, robust_acc in zip(epsilons, robust_accuracy):\n",
    "    print(eps, robust_acc.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-22T11:33:22.705849Z",
     "start_time": "2020-09-22T11:33:22.703489Z"
    }
   },
   "outputs": [],
   "source": [
    "# Optionally save images and labels for comparison with classic Foolbox.\n",
    "# Set SAVE_SAMPLES = True to write images.npy / labels.npy to the working directory.\n",
    "SAVE_SAMPLES = False\n",
    "if SAVE_SAMPLES:\n",
    "    import numpy as np\n",
    "\n",
    "    np.save(\"images.npy\", images.numpy())\n",
    "    np.save(\"labels.npy\", labels.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
